
import numpy as	np
import time
import sys
import matplotlib.pyplot as	plt
import os


#--------------------------------#
from FwdBwdNeuralEqV3 import *
from Tx	import *
from eq	import *
from neuralEQ import *
from utils import *
import device


# Number of amplitude levels for each supported modulation format.
_MOD_NUM = {'nrz': 2, 'pam4': 4, 'pam8': 8}


def _mod_num(mod):
    """Return the number of symbol levels for modulation name ``mod``.

    Calls ``sys.exit('invalid modulation')`` for any name not in ``_MOD_NUM``.
    """
    if mod not in _MOD_NUM:
        sys.exit('invalid modulation')
    return _MOD_NUM[mod]


def _effective_train_snr(trainCfg, snrTrain):
    """Return the SNR used to generate the *training* channel output.

    When ``trainCfg['mismatchSNR']`` is set, it is added to ``snrTrain``. The
    training data is then produced at a different SNR than the validation/test
    data, which always use ``snrTrain``.
    """
    mismatch = trainCfg['mismatchSNR']
    return snrTrain if mismatch is None else snrTrain + mismatch


def _fwd_bwd_cache_files(trainCfg, snrTrain):
    """Return the ``(prob, chOut, chIn)`` cache file paths for one training SNR.

    With ``forceTrainIn`` set, fixed file names are used. Otherwise the names
    encode modulation, training-set size, equalizer SBR and the (possibly
    mismatched) training SNR.
    """
    if trainCfg['forceTrainIn']:
        return (
            'caching_data/probNew_less09.list',
            'caching_data/chOutNew_less09.list',
            'caching_data/chInNew_less09.list',
        )
    template = './caching_data/%s_%s_size%d_%s_snr%ddB.list'
    snr = _effective_train_snr(trainCfg, snrTrain)
    return tuple(
        template % (trainCfg['mod'], tag, trainCfg['dataSizeTrain'], trainCfg['eqSBR'], snr)
        for tag in ('fwdBwdProb', 'fwdBwdProbChOut', 'fwdBwdProbChIn')
    )


def _build_neural_eq(trainCfg, delay, modNum):
    """Construct the neural equalizer described by ``trainCfg`` and move it to ``device.device``.

    When ``usePrunedNeuralEq`` is set, the freshly built network is replaced by
    the model stored in ``prunedModelFile``. The build is kept first so the
    construction order matches earlier versions of this script.
    """
    if trainCfg['useFwdBwdNeuralEq']:
        nEQ = FwdBwdNeuralEq(
            trainCfg['hiddenStage'],
            trainCfg['inSize'],
            delay,
            trainCfg['N'],
            trainCfg['batchSize'],
            trainCfg['mod'],
        )
    else:
        # Presumably one output unit per (output symbol, level) pair -- confirm
        # against neuralEQ. Setting the output size to 1 for NRZ
        # ("nrzNnOutOne") was tried and reportedly did not work.
        nEQ = neuralEQ(
            inSize=trainCfg['inSize'],
            outSize=trainCfg['outSize'] * modNum,
            mod=trainCfg['mod'],
            nnSel=0,
        )
    if trainCfg['usePrunedNeuralEq']:
        # NOTE(review): torch.load unpickles the file -- only load trusted models.
        # Recent PyTorch releases default to weights_only=True, which rejects a
        # whole pickled module; pass weights_only=False there if loading fails.
        nEQ = torch.load(trainCfg['prunedModelFile'])
    return nEQ.to(device.device)


def _evaluate_on_test(nEQ, chInTest, chOutTest, lossFn, trainCfg, delay):
    """Run ``nEQ`` on the test sequence and return ``(testLoss, berTest)``."""
    simNEQ = simNeuralEQ(
        txDataTrain=None,
        rxDataTrain=None,
        txDataTest=chInTest,
        rxDataTest=chOutTest,
        neuralEQ=nEQ,
        mod=trainCfg['mod'],
    )
    return simNEQ.evalNeuralEQ(
        lossFn,
        batchSize=trainCfg['batchSize'],
        inSize=trainCfg['inSize'],
        outSize=trainCfg['outSize'],
        delay=delay,
    )


def _load_or_generate_train_data(trainCfg, tx, snrTrain, selTrainData):
    """Return ``(chInTrain, chOutTrain, fwdBwdProbTrain)`` for one training SNR.

    If the fwdBwd probability cache file exists, all three items are loaded
    from the cache.

    Otherwise a fresh training sequence is generated. When ``selTrainData == 0``,
    the fwdBwd algorithm is also run on it and the three results are cached.
    ``fwdBwdProbTrain`` is ``None`` when it was neither loaded nor computed.
    """
    probFile, chOutFile, chInFile = _fwd_bwd_cache_files(trainCfg, snrTrain)

    if os.path.exists(probFile):
        print("")
        print("File(%s)\texists,\tload from file" % probFile)
        print("")
        fwdBwdProbTrain = loadList(probFile)
        chOutTrain = loadList(chOutFile)
        chInTrain = loadList(chInFile)
    else:
        chInTrain = tx.run(int(trainCfg['dataSizeTrain']))
        ch = Channel(sbr=trainCfg['chSBR'], snr=_effective_train_snr(trainCfg, snrTrain))
        chOutTrain = ch.run(chIn=chInTrain, flagN=trainCfg['noiseFlag'])
        fwdBwdProbTrain = None

        if selTrainData == 0:
            print("")
            print("File(%s)\tno exists, excute fwdBwd" % probFile)
            print("")
            sweepForTrain = simSweep(
                chSbr=trainCfg['chSBR'],
                eqSbr=trainCfg['eqSBR'],
                snrList=[snrTrain],
                originData=chInTrain,
                chOutList=[chOutTrain],
                mod=trainCfg['mod'],
                stateGen=True,
            )
            _, fwdBwdProbTrain = sweepForTrain.fwdBwd(fwdBwdLen=trainCfg['inSize'])
            # Make sure the cache directory exists before writing into it.
            for cacheDir in {os.path.dirname(p) for p in (probFile, chOutFile, chInFile)}:
                os.makedirs(cacheDir or '.', exist_ok=True)
            saveList(probFile, fwdBwdProbTrain)
            saveList(chOutFile, chOutTrain)
            saveList(chInFile, chInTrain)

    if selTrainData == 0 and fwdBwdProbTrain is not None:
        fwdBwdProbTrain = np.array(fwdBwdProbTrain)
    return chInTrain, chOutTrain, fwdBwdProbTrain


def _save_training_plots(runName, mod, snrTrain, trainLossHis, berValidHis, berTest):
    """Save the training-loss and validation-BER curves under ./results/<runName>_TRAIN/.

    Each figure is closed after saving, so figures do not accumulate across the sweep.

    NOTE(review): the file names depend only on ``mod`` and ``snrTrain``, so
    runs for different selTrainData/lossFn overwrite each other's plots.
    """
    plt.figure()
    plt.plot(trainLossHis, '-', label='trainloss')
    plt.grid(True)
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.yscale('log')
    plt.legend(loc='best')
    plt.savefig('./results/%s_TRAIN/loss_%s_%ddB.png' % (runName, mod, snrTrain))
    plt.close()

    plt.figure()
    plt.plot(berValidHis, '-', label='validber')
    # Flat reference line at the final test BER, spanning every epoch.
    plt.plot([berTest] * len(berValidHis), '--', label='nnFinalBer')
    plt.grid(True)
    plt.yscale('log')
    plt.ylim([1e-4, 1])
    plt.xlabel('epoch')
    plt.ylabel('ber(accuracy)')
    plt.legend(loc='best')
    plt.savefig('./results/%s_TRAIN/ber_%s_%ddB.png' % (runName, mod, snrTrain))
    plt.close()


if __name__ == "__main__":
    # Neural-EQ training sweep.
    #
    # The script iterates over every combination of
    # cfg['train']['selTrainDataList'] x cfg['train']['lossFnList'] x
    # cfg['train']['snrTrainList']. For each combination it:
    #   1. builds a neural equalizer,
    #   2. evaluates it on a fresh test set,
    #   3. trains it,
    #   4. saves it under ./results/<name>_TRAIN/,
    #   5. evaluates it again on the same test set.
    # To shrink or grow the sweep, edit those lists in the config file.

    # ------------------------------ header ------------------------------ #
    startTime = time.time()
    np.random.seed(1)
    args = parsing_def()
    sys.path.insert(0, './config')
    config_module = __import__('config_{}'.format(args.config))
    cfg = config_module.config
    trainCfg = cfg['train']

    modNum = _mod_num(trainCfg['mod'])

    delay = int(trainCfg['inSize'] / 4)
    if trainCfg['delayOffset'] is not None:
        delayOffset = trainCfg['delayOffset']
    else:
        # Default: negative index of the SBR main cursor (its peak sample).
        sbr = list(trainCfg['chSBR'])
        delayOffset = -sbr.index(max(sbr))
        print(f"Calculated delay offset is {delayOffset}")
    evalDelay = delay + delayOffset

    # Create the output directory up front, so a missing directory fails fast
    # instead of after a full training run.
    os.makedirs('./results/%s_TRAIN' % args.name, exist_ok=True)

    # Tx generates random symbols for the configured modulation. Separate
    # Channel instances produce the validation and test sets for each SNR.
    # The test set is used for both the initial and the final nEQ evaluation.
    tx = Tx(mod=trainCfg['mod'])

    for selTrainData in trainCfg['selTrainDataList']:
        for lossFn in trainCfg['lossFnList']:
            print("")
            print("")
            print(f"selTrainData: {selTrainData}")
            print(f"lossFn: {lossFn}")
            for idx, snrTrain in enumerate(trainCfg['snrTrainList']):

                # Validation sequence, monitored during training.
                chInValid = tx.run(int(trainCfg['dataSizeValid']))
                chValid = Channel(sbr=trainCfg['chSBR'], snr=snrTrain)
                chOutValid = chValid.run(chIn=chInValid, flagN=trainCfg['noiseFlag'])

                # Test sequence, used for the initial and final evaluation.
                chInTest = tx.run(int(trainCfg['dataSizeTest']))
                chTest = Channel(sbr=trainCfg['chSBR'], snr=snrTrain)
                chOutTest = chTest.run(chIn=chInTest, flagN=trainCfg['noiseFlag'])

                print("")
                print(f"trainIdx: {idx}\t\t snrTrain: {snrTrain}")
                print("")

                nEQ = _build_neural_eq(trainCfg, delay, modNum)

                # Baseline: untrained (or pruned) nEQ on the test set.
                testLoss, berTest = _evaluate_on_test(
                    nEQ, chInTest, chOutTest, lossFn, trainCfg, evalDelay)
                print(f"selTrainData: {selTrainData}")
                print(f"lossFn: {lossFn}")
                print(f"Initialtestloss: {testLoss:e}, testber: {berTest:e} @\tselTrainData:{selTrainData}, lossFn:{lossFn}", flush=True)

                # Adam optimizer. No LR scheduler is used.
                opt = torch.optim.Adam(
                    nEQ.parameters(),
                    lr=trainCfg['lr'],
                    weight_decay=trainCfg['weightDecay'],
                )

                summary(
                    nEQ,
                    (trainCfg['batchSize'], trainCfg['inSize']),
                    batch_size=trainCfg['batchSize'],
                    device=device.device,
                )

                # NOTE(review): fwdBwdProbTrain is loaded/computed (and cached) but
                # is not currently passed to trainEval.
                chInTrain, chOutTrain, fwdBwdProbTrain = _load_or_generate_train_data(
                    trainCfg, tx, snrTrain, selTrainData)

                if trainCfg['onTheFly']:
                    # Presumably trainEval generates its own data when given None -- confirm.
                    chInTrain = None
                    chOutTrain = None

                trainLossHis, validLossHis, berValidHis = trainEval(
                    nEQ,
                    tx,
                    chInValid,
                    chOutValid,
                    trainCfg['numEpoch'],
                    trainCfg['evalFreq'],
                    trainCfg['mod'],
                    trainCfg['chSBR'],
                    trainCfg['inSize'],
                    trainCfg['outSize'],
                    trainCfg['batchSize'],
                    evalDelay,
                    lossFn,
                    opt,
                    int(trainCfg['dataSizeTrain']),
                    snrTrain,
                    trainCfg['noiseFlag'],
                    chInTrain,
                    chOutTrain,
                    trainSnrVariation=trainCfg['trainSnrVariation'],
                )

                # Save the trained network for this (snr, selTrainData, lossFn).
                torch.save(nEQ, './results/%s_TRAIN/nEQ_%s_%ddB_simp%d_%s.pt' % (
                    args.name,
                    trainCfg['mod'],
                    snrTrain,
                    selTrainData,
                    lossFn))

                # Final evaluation of the trained nEQ on the same test set.
                testLoss, berTest = _evaluate_on_test(
                    nEQ, chInTest, chOutTest, lossFn, trainCfg, evalDelay)
                print(f"selTrainData: {selTrainData}")
                print(f"lossFn: {lossFn}")
                print(f"Finaltestloss: {testLoss:e}, testber: {berTest:e} @\tselTrainData:{selTrainData}, lossFn:{lossFn}", flush=True)

                if trainCfg['plotLoss']:
                    _save_training_plots(
                        args.name, trainCfg['mod'], snrTrain,
                        trainLossHis, berValidHis, berTest)

    timeSim = (time.time() - startTime) / 60.  # Unit: minutes
    print(f"Total simulation time: {timeSim} mins")
